package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.DCT

/**
 * Companion factory for [[DCTTest]].
 *
 * Exposes the test through the shared [[BaseTest]] entry point so the
 * harness can instantiate it uniformly: `DCTTest()` yields a fresh
 * [[BaseSpark]] instance.
 */
object DCTTest extends BaseTest {

  /** Builds a new [[DCTTest]] instance, typed as [[BaseSpark]]. */
  def apply(): BaseSpark = {
    new DCTTest()
  }

}

/**
 * Demonstrates the Spark ML [[org.apache.spark.ml.feature.DCT]] transformer,
 * which applies a 1-D Discrete Cosine Transform (DCT-II) to a length-N
 * real-valued feature vector, producing another length-N vector in the
 * frequency domain.
 */
class DCTTest extends BaseSpark {

    /**
     * Builds a small fixture DataFrame of three rows, each holding an id
     * and a dense 4-element feature vector to be transformed.
     *
     * @param sparkSession session used to create the DataFrame; defaults to
     *                     the one provided by [[BaseSpark]]
     * @return DataFrame with columns ("id", "features")
     */
    override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
        sparkSession.createDataFrame(Seq(
            (0, Vectors.dense(0.0, 1.0, -2.0, 3.0)),
            (1, Vectors.dense(-1.0, 2.0, 4.0, -7.0)),
            (2, Vectors.dense(14.0, -2.0, -5.0, 1.0))
        ))
        .toDF("id", "features")
    }

    /**
     * Configures a DCT transformer, applies it to the input DataFrame and
     * prints the transformed rows plus both schemas for inspection.
     *
     * @param dataFrame input data containing a vector column named "features"
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Column names: input feature column and the column the DCT writes.
        val feature = "features"
        val featureNew = "features_dtc"
        // Configure the transformer.
        val dct = new DCT()
            .setInputCol(feature)       // column holding the vectors to transform
            .setOutputCol(featureNew)   // column to receive the transformed vectors
            .setInverse(false)          // true: inverse DCT; false (default): forward DCT
        // Apply the transform (lazy; materialized by the show() below).
        val transformed = dct.transform(dataFrame)
        // Show up to 100 rows without truncating cells beyond 100 chars.
        transformed.show(100, 100)

        dataFrame.show(false)
        dataFrame.printSchema()
        transformed.printSchema()
    }

}
